package ntp

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"net"
	"strings"
	"time"
)

const (
	// UNIX_STA_TIMESTAMP is the number of seconds between the NTP epoch
	// (1900-01-01 00:00:00 UTC) and the Unix epoch (1970-01-01 00:00:00 UTC).
	// It is subtracted from the seconds part of an NTP timestamp to obtain
	// Unix seconds, and added for the reverse conversion.
	UNIX_STA_TIMESTAMP = 2208988800
)

// Ntp models the NTP packet header, which is 48 bytes long on the wire.
// Protocol documentation: http://www.ntp.org/documentation.html
//
// Li, Vn and Mode share the first byte of the packet; every other field is
// stored in declaration order using big-endian byte order.
type Ntp struct {
	Li                  uint8  // 2 bits: leap second indicator
	Vn                  uint8  // 3 bits: NTP version number, 3 or 4
	Mode                uint8  // 3 bits: 3 = client request, 4 = server response
	Stratum             uint8  // stratum (level) of the server clock; 1 is the highest
	Poll                uint8  // poll field of the header; carried through unchanged by this file
	Precision           uint8  // precision of the server clock
	RootDelay           int32  // maximum round-trip delay between the server and the primary time source, in seconds (NOTE(review): RFC 5905 encodes this as 16.16 fixed point — confirm before interpreting)
	RootDispersion      int32  // maximum error of the server relative to the primary time source, in seconds (NOTE(review): same fixed-point caveat as RootDelay)
	ReferenceIdentifier int32  // identifier of the primary time source
	ReferenceTimestamp  uint64 // time at which the server clock was last calibrated
	OriginateTimestamp  uint64 // time at which the client sent its request to the server
	ReceiveTimestamp    uint64 // time at which the server received the client request
	TransmitTimestamp   uint64 // time at which the server sent its reply to the client
}

// NewNtp returns a header preset for a client request: leap indicator 0,
// version 3, mode 3 and stratum 0. Every other field is left at zero; those
// values are normally supplied by the server in its reply.
func NewNtp() (p *Ntp) {
	p = new(Ntp)
	p.Vn, p.Mode = 3, 3
	return p
}

// Encode serializes the header into the 48-byte NTP wire format.
//
// Byte 0 packs Li (upper 2 bits), Vn (middle 3 bits) and Mode (lower 3 bits).
// Each value is masked to its field width so that an out-of-range value
// cannot spill into a neighbouring field. Bytes 1-3 carry Stratum, Poll and
// Precision; the remaining fields follow in declaration order, big-endian
// (network byte order).
//
// For a header with Vn = 3 and Mode = 3 and everything else zero the result
// is 0x1b followed by 47 zero bytes.
func (n *Ntp) Encode() []byte {
	pkt := make([]byte, 48)

	pkt[0] = (n.Li&0x03)<<6 | (n.Vn&0x07)<<3 | n.Mode&0x07
	pkt[1] = n.Stratum
	pkt[2] = n.Poll
	pkt[3] = n.Precision

	// The signed 32-bit fields are written bit-for-bit; the conversion to
	// uint32 only changes the static type, not the encoded bytes.
	binary.BigEndian.PutUint32(pkt[4:8], uint32(n.RootDelay))
	binary.BigEndian.PutUint32(pkt[8:12], uint32(n.RootDispersion))
	binary.BigEndian.PutUint32(pkt[12:16], uint32(n.ReferenceIdentifier))

	binary.BigEndian.PutUint64(pkt[16:24], n.ReferenceTimestamp)
	binary.BigEndian.PutUint64(pkt[24:32], n.OriginateTimestamp)
	binary.BigEndian.PutUint64(pkt[32:40], n.ReceiveTimestamp)
	binary.BigEndian.PutUint64(pkt[40:48], n.TransmitTimestamp)

	return pkt
}

// Decode parses a server reply in NTP wire format into the receiver.
//
// buf must contain at least the 48-byte header; bytes beyond the header are
// ignored. If buf is shorter than that the receiver is left unchanged.
//
// When useUnixSec is true the four timestamp fields are replaced by whole
// Unix seconds: the upper 32 bits of each NTP timestamp (its integer-second
// part) minus UNIX_STA_TIMESTAMP. OriginateTimestamp is converted only when
// it is non-zero. When useUnixSec is false the raw 64-bit NTP timestamps are
// stored as received.
func (n *Ntp) Decode(buf []byte, useUnixSec bool) {
	// Mirror of the wire layout; binary.Read fills it in one pass and
	// reports an error if buf cannot supply the whole header.
	var wire struct {
		Head                uint8
		Stratum             uint8
		Poll                uint8
		Precision           uint8
		RootDelay           int32
		RootDispersion      int32
		ReferenceIdentifier int32
		ReferenceTimestamp  uint64
		OriginateTimestamp  uint64
		ReceiveTimestamp    uint64
		TransmitTimestamp   uint64
	}
	if err := binary.Read(bytes.NewReader(buf), binary.BigEndian, &wire); err != nil {
		// Truncated packet: keep the receiver as it was instead of storing
		// a partially decoded header.
		return
	}

	// First byte: LI (2 bits) | VN (3 bits) | Mode (3 bits).
	n.Li = wire.Head >> 6
	n.Vn = (wire.Head >> 3) & 0x07
	n.Mode = wire.Head & 0x07

	n.Stratum = wire.Stratum
	n.Poll = wire.Poll
	n.Precision = wire.Precision
	n.RootDelay = wire.RootDelay
	n.RootDispersion = wire.RootDispersion
	n.ReferenceIdentifier = wire.ReferenceIdentifier
	n.ReferenceTimestamp = wire.ReferenceTimestamp
	n.OriginateTimestamp = wire.OriginateTimestamp
	n.ReceiveTimestamp = wire.ReceiveTimestamp
	n.TransmitTimestamp = wire.TransmitTimestamp

	if !useUnixSec {
		return
	}

	// Convert to Unix seconds; the upper 32 bits are the whole seconds.
	n.ReferenceTimestamp = (n.ReferenceTimestamp >> 32) - UNIX_STA_TIMESTAMP
	if n.OriginateTimestamp > 0 {
		n.OriginateTimestamp = (n.OriginateTimestamp >> 32) - UNIX_STA_TIMESTAMP
	}
	n.ReceiveTimestamp = (n.ReceiveTimestamp >> 32) - UNIX_STA_TIMESTAMP
	n.TransmitTimestamp = (n.TransmitTimestamp >> 32) - UNIX_STA_TIMESTAMP
}

// ToString renders the header for debugging. The receiver is formatted with
// fmt's "%+v" verb, so every field is printed together with its name.
func (n *Ntp) ToString() string {
	text := fmt.Sprintf("%+v", n)
	return text
}

// TimeOfNtpServer queries an NTP server over UDP and returns the server's
// transmit time together with the measured round-trip duration.
//
// ntp_server_addr is "host" or "host:port". When it contains no ':' (or
// starts with one) the default NTP port ":123" is appended.
// NOTE(review): a bracketed IPv6 literal without a port already contains ':'
// and therefore does not get the default port.
//
// now is the reply's TransmitTimestamp truncated to whole seconds. cast is
// the time elapsed from just before the request is written until just after
// the reply is read.
//
// err is non-nil when dialing, setting a deadline, writing (2s deadline) or
// reading (5s deadline) fails, or when the reply is shorter than the 48-byte
// NTP header. Errors from the net package are returned unwrapped. On error
// now and cast are zero.
func TimeOfNtpServer(ntp_server_addr string) (now time.Time, cast time.Duration, err error) {
	if strings.Index(ntp_server_addr, ":") < 1 {
		ntp_server_addr += ":123" // default NTP port
	}

	conn, err := net.Dial("udp", ntp_server_addr)
	if err != nil {
		return time.Time{}, 0, err
	}
	defer conn.Close()

	req := NewNtp().Encode()
	res := make([]byte, 1024)

	sent := time.Now()
	if err = conn.SetWriteDeadline(sent.Add(2 * time.Second)); err != nil {
		return time.Time{}, 0, err
	}
	if _, err = conn.Write(req); err != nil {
		return time.Time{}, 0, err
	}

	if err = conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
		return time.Time{}, 0, err
	}
	n, err := conn.Read(res)
	if err != nil {
		return time.Time{}, 0, err
	}
	cast = time.Since(sent)

	// A valid reply carries at least the full header; anything shorter would
	// decode into a meaningless timestamp.
	if n < 48 {
		return time.Time{}, 0, fmt.Errorf("ntp: short response from %s: got %d bytes, want at least 48", ntp_server_addr, n)
	}

	reply := NewNtp()
	reply.Decode(res[:n], true) // only the bytes actually received
	return time.Unix(int64(reply.TransmitTimestamp), 0), cast, nil
}

// NewNtpV3Req builds a minimal NTP version 3 client request: a fresh 48-byte
// packet that is all zero except for the first byte, 0x1b (leap indicator 0,
// version 3, mode 3).
func NewNtpV3Req() []byte {
	var pkt [48]byte
	pkt[0] = 0x1b
	return pkt[:]
}

// GetT2T3Unix extracts two server-side timestamps from a raw NTP reply and
// returns them as whole Unix seconds: first the receive timestamp (T2, bytes
// 32-39), then the transmit timestamp (T3, bytes 40-47). Only the upper 32
// bits (integer seconds) of each timestamp are used.
//
// ntpRes must be at least 48 bytes long; a shorter slice causes a panic.
func GetT2T3Unix(ntpRes []byte) (int64, int64) {
	recvSec := binary.BigEndian.Uint64(ntpRes[32:]) >> 32
	xmitSec := binary.BigEndian.Uint64(ntpRes[40:]) >> 32
	return int64(recvSec - UNIX_STA_TIMESTAMP), int64(xmitSec - UNIX_STA_TIMESTAMP)
}

// ParseTransmitTimestamp extracts the server's transmit timestamp (bytes
// 40-47) from a raw NTP reply and returns it as a time.Time with whole-second
// resolution; the fractional part of the NTP timestamp is discarded.
//
// ntpRes must be at least 48 bytes long; a shorter slice causes a panic.
func ParseTransmitTimestamp(ntpRes []byte) time.Time {
	unixSec := binary.BigEndian.Uint64(ntpRes[40:])>>32 - UNIX_STA_TIMESTAMP
	return time.Unix(int64(unixSec), 0)
}

// SimpleNtpQuery performs a lightweight NTP query: it sends the fixed
// version 3 request built by NewNtpV3Req and reads only the transmit
// timestamp from the reply.
//
// ntp_server_addr is "host" or "host:port". When it contains no ':' (or
// starts with one) the default NTP port ":123" is appended.
// NOTE(review): a bracketed IPv6 literal without a port already contains ':'
// and therefore does not get the default port.
//
// now is the server's transmit time truncated to whole seconds. cast is the
// time elapsed from just before the request is written until just after the
// reply is read.
//
// err is non-nil when dialing, setting a deadline, writing (2s deadline) or
// reading (5s deadline) fails, or when the reply is shorter than the 48-byte
// NTP header. Errors from the net package are returned unwrapped. On error
// now and cast are zero.
func SimpleNtpQuery(ntp_server_addr string) (now time.Time, cast time.Duration, err error) {
	if strings.Index(ntp_server_addr, ":") < 1 {
		ntp_server_addr += ":123" // default NTP port
	}

	conn, err := net.Dial("udp", ntp_server_addr)
	if err != nil {
		return time.Time{}, 0, err
	}
	defer conn.Close()

	req := NewNtpV3Req()
	res := make([]byte, 1024)

	sent := time.Now()
	if err = conn.SetWriteDeadline(sent.Add(2 * time.Second)); err != nil {
		return time.Time{}, 0, err
	}
	if _, err = conn.Write(req); err != nil {
		return time.Time{}, 0, err
	}

	if err = conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
		return time.Time{}, 0, err
	}
	n, err := conn.Read(res)
	if err != nil {
		return time.Time{}, 0, err
	}
	cast = time.Since(sent)

	// Report a truncated reply instead of silently returning the zero time.
	if n < 48 {
		return time.Time{}, 0, fmt.Errorf("ntp: short response from %s: got %d bytes, want at least 48", ntp_server_addr, n)
	}
	return ParseTransmitTimestamp(res[:n]), cast, nil
}

// NtpTimestamp2Unix converts a 64-bit NTP timestamp to whole Unix seconds.
// Only the upper 32 bits (integer seconds since the NTP epoch) are used; the
// fractional lower 32 bits are discarded.
func NtpTimestamp2Unix(ts uint64) int64 {
	seconds := ts >> 32
	return int64(seconds - UNIX_STA_TIMESTAMP)
}

// UnixTimestamp2Ntp converts whole Unix seconds to a 64-bit NTP timestamp.
// The seconds since the NTP epoch are placed in the upper 32 bits and the
// fractional lower 32 bits are left at zero.
func UnixTimestamp2Ntp(uts int64) uint64 {
	return (uint64(uts) + UNIX_STA_TIMESTAMP) << 32
}

// Calc derives the round-trip delay and the clock offset, both in whole
// seconds, from a raw NTP reply and the local send/receive times.
//
// res is the raw reply; it is passed to GetT2T3Unix to obtain the server's
// receive (T2) and transmit (T3) times. reqTs is the local Unix time at which
// the request was sent (T1) and resTs the local Unix time at which the reply
// arrived (T4).
//
// The results follow the standard NTP on-wire equations:
//
//	d = (T4 - T1) - (T3 - T2)       round-trip delay
//	t = ((T2 - T1) + (T3 - T4)) / 2 offset of the server clock relative to the local clock
//
// The division is integer division and truncates toward zero.
func Calc(res []byte, reqTs, resTs int64) (d int64, t int64) {
	t1, t4 := reqTs, resTs
	t2, t3 := GetT2T3Unix(res)
	d = (t4 - t1) - (t3 - t2)
	// The two half-offsets are added; subtracting them would merely yield
	// d/2 and carry no information about the clock difference.
	t = ((t2 - t1) + (t3 - t4)) / 2
	return d, t
}
